#!/usr/bin/env python3

# Usage: ~/recording_ws/src/record_social_dining/extract_all_bags_audios.py


import glob
import os
import time
import wave

from ros import rosbag


def extract_audio(bag_path: str, topic_name: str, output_path: str,
                  sample_rate: int = 16000, sample_width: int = 2,
                  num_channels: int = 1) -> None:
    """Write the audio payload of one bag topic to a WAV file.

    Every message read from ``topic_name`` whose type is
    ``audio_common_msgs/AudioData`` has its ``data`` bytes appended to the
    WAV file, in the order the bag yields them. Messages of any other type
    on that topic are skipped.

    Args:
        bag_path: Path of the bag file to read.
        topic_name: Name of the topic to extract, e.g. ``'/audio'``.
        output_path: Path of the WAV file to create (opened in ``'wb'`` mode).
        sample_rate: Frame rate in Hz written to the WAV header.
        sample_width: Bytes per sample written to the WAV header.
        num_channels: Number of channels written to the WAV header.

    The header values are not read from the bag; they have to match the
    configuration of the node that recorded the audio (the respeaker node
    for the defaults).
    """
    # Context managers close both handles even if reading or writing fails,
    # so the bag is never leaked and the WAV header is always finalized.
    # The bag is opened first so an unreadable bag does not create a WAV file.
    with rosbag.Bag(bag_path) as bag, wave.open(output_path, 'wb') as wav_file:
        wav_file.setframerate(sample_rate)
        wav_file.setsampwidth(sample_width)
        wav_file.setnchannels(num_channels)

        for _topic, msg, _stamp in bag.read_messages(topics=[topic_name]):
            if msg._type == 'audio_common_msgs/AudioData':
                wav_file.writeframes(msg.data)


if __name__ == '__main__':
    # Batch-convert every bag in input_dir into one WAV file per bag in
    # output_dir, then report how many bags were processed and how long it took.

    # Topic to extract; change this to pull audio from a different topic.
    topic_name = '/audio'
    # Alternative input location (bags inside the workspace):
    # input_dir = '/home/emprise/recording_ws/src/record_social_dining/bag'
    input_dir = '/media/emprise/FCF4C425F4C3E04E/Janko-SocialDining_Backup/bag'
    output_dir = '/home/emprise/recording_ws/src/record_social_dining/extracted_audio'

    # wave.open() does not create missing directories, so make sure the
    # destination exists before the first file is written.
    os.makedirs(output_dir, exist_ok=True)

    bag_filepaths = sorted(glob.glob(os.path.join(input_dir, '*.bag')))
    start_time = time.time()

    for bag_filepath in bag_filepaths:
        # Strip only the final extension so names that contain dots
        # (e.g. 'session.1.bag') keep their full stem and do not collide.
        session_id = os.path.splitext(os.path.basename(bag_filepath))[0]
        output_path = os.path.join(output_dir, f'{session_id}.wav')
        extract_audio(bag_filepath, topic_name, output_path)

    print(f"Extracted wav audios from {len(bag_filepaths)} bag files.")
    print(f"Time taken (seconds): {time.time() - start_time}")
